{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Policy Laplace on Spark\n",
    "\n",
    "Spark implementation of Policy Laplace from [Differentially Private Set Union](https://arxiv.org/abs/2002.09745)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyspark\n",
    "import numpy as np\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.functions import *\n",
    "spark = SparkSession.builder.getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: wget in /Users/joshuaallen/opt/anaconda3/lib/python3.8/site-packages (3.2)\r\n"
     ]
    }
   ],
   "source": [
    "import zipfile\n",
    "from os import path\n",
    "\n",
    "csv_file = 'clean_askreddit.csv'\n",
    "zip_file = csv_file + '.zip'\n",
    "zip_url = 'https://dp-test-datasets.s3.amazonaws.com/clean_askreddit.csv.zip'\n",
    "\n",
    "# Fetch and unpack the dataset only when the extracted CSV is missing.\n",
    "if not path.exists(csv_file):\n",
    "    # Download only when the archive is absent. The previous inner check\n",
    "    # re-tested the CSV path, so an existing archive was downloaded again.\n",
    "    if not path.exists(zip_file):\n",
    "        %pip install wget\n",
    "        import wget\n",
    "        wget.download(zip_url)\n",
    "    # Bind the archive to a name that does not shadow the built-in zip().\n",
    "    with zipfile.ZipFile(zip_file, 'r') as archive:\n",
    "        archive.extractall('.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the AskReddit CSV (header row, inferred column types), then drop rows containing nulls.\n",
    "filepath = \"clean_askreddit.csv\"\n",
    "csv_reader = spark.read.format(\"csv\").options(sep=\",\", inferSchema=\"true\", header=\"true\")\n",
    "reddit = csv_reader.load(filepath).dropna()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare Data for Processing\n",
    "\n",
    "Load the data from file and tokenize.  This code can be any caller-specific tokenization routine, and is independent of differential privacy.  The output RDD should include one list of tokens per row, but can have multiple rows per user, and does not need to be ordered in any way.  This stage can be combined with other n-grams (e.g. 2-grams, 3-grams) and persisted to feed to DPSU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import nltk\n",
    "\n",
    "n_grams = 1\n",
    "distinct = True\n",
    "\n",
    "def tokenize(user_post):\n",
    "    \"\"\"Split one (user, post) record into (user, tokens).\n",
    "\n",
    "    Reads the notebook-level settings `n_grams` (when > 1, tokens are replaced\n",
    "    by underscore-joined n-grams) and `distinct` (when true, duplicate tokens\n",
    "    within the post are dropped).\n",
    "    \"\"\"\n",
    "    user, post = user_post\n",
    "    # split() with no separator collapses whitespace runs, so repeated, leading,\n",
    "    # or trailing spaces do not produce empty-string tokens.\n",
    "    tokens = post.split()\n",
    "    if n_grams > 1:\n",
    "        tokens = [\"_\".join(gram) for gram in nltk.ngrams(tokens, n_grams)]\n",
    "    if distinct:\n",
    "        # dict.fromkeys de-duplicates while keeping first-seen order.\n",
    "        tokens = list(dict.fromkeys(tokens))\n",
    "    return (user, tokens)\n",
    "        \n",
    "tokenized = reddit.select(\"author\", \"clean_text\").rdd.map(tokenize).persist()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Instantiate DPSU Processor\n",
    "\n",
    "Create the object and pass in the privacy parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Params Delta_0=500, delta=4.54e-05, l_param=0.3333333333333333, l_rho=5.175812754165355, Gamma=6.842479420832021\n"
     ]
    }
   ],
   "source": [
    "from policy_laplace import PolicyLaplace\n",
    "\n",
    "epsilon = 3.0\n",
    "delta = np.exp(-10)\n",
    "alpha = 5.0\n",
    "tokens_per_user = 500\n",
    "prune_tail_below = None\n",
    "num_partitions = 1\n",
    "\n",
    "pl = PolicyLaplace(epsilon, delta, alpha, tokens_per_user, prune_tail_below, num_partitions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prune the tail\n",
    "pruned = pl.prune_tail(tokenized)\n",
    "\n",
    "# reservoir sample the input tokens\n",
    "sampled = pl.reservoir_sample(pruned, distinct).persist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "counted = pl.process_partitions(sampled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Retrieved 13210 words from 150588\n",
      "['ca', 'changed', 'reddit', 'america', 'oh']\n"
     ]
    }
   ],
   "source": [
    "# Persist the filtered words so the count and the sample printed below reuse one\n",
    "# evaluation of the filter instead of re-running it for every action.\n",
    "# NOTE(review): pl.exceeds_threshold is presumably randomized; if so, re-evaluating\n",
    "# it per action could also make the two printed results disagree - confirm.\n",
    "good = counted.filter(lambda row: pl.exceeds_threshold(row[1])).map(lambda row: row[0]).persist()\n",
    "\n",
    "print(\"Retrieved {0} words from {1}\".format(good.count(),counted.count()))\n",
    "print(good.take(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
